# utils.py
import random
import numpy as np
import torch

def set_seed(seed):
    """
    Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, every CUDA device's generator is seeded too.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def gaussian_gram_matrix(emb, sigma=1.0):
    """
    Compute the Gaussian (RBF) Gram matrix from an embedding tensor.

    Samples are indexed along dim 0. The squared distance between samples
    i and j is the element-wise squared difference, averaged over dim 1 and
    then summed over the remaining feature dim:

    * 1-D input ``(n,)`` is treated as ``(n, 1, 1)``: scalar squared distance.
    * 2-D input ``(n, d)`` is treated as ``(n, 1, d)``: squared Euclidean
      distance.
    * 3-D input ``(n, t, d)``: mean over ``t`` of the per-position squared
      Euclidean distance.

    Args:
        emb: embedding tensor with samples along dim 0.
        sigma: kernel bandwidth. ``sigma == 0`` divides by zero, so pass a
            non-zero value.

    Returns:
        For 1-D, 2-D or 3-D inputs, an ``(n, n)`` tensor with
        ``gram[i, j] = exp(-dist_sq[i, j] / (2 * sigma**2))``.
    """
    # The reduction below needs at least 3 dims. Lift lower-rank inputs by
    # inserting singleton axes, which leaves the distance unchanged.
    if emb.dim() == 1:
        emb = emb.unsqueeze(-1)
    if emb.dim() == 2:
        emb = emb.unsqueeze(1)

    # Pairwise differences via broadcasting: shape (n, n, t, d) for 3-D input.
    diff = emb.unsqueeze(0) - emb.unsqueeze(1)
    # Average over dim 2 (the original dim 1), then sum over the next dim.
    dist_sq = diff.pow(2).mean(dim=2).sum(dim=2)
    return torch.exp(-dist_sq / (2 * sigma**2))

def compute_von_neumann_entropy_from_density(rho):
    """
    Compute the von Neumann entropy ``S = -sum_i lambda_i * ln(lambda_i)``.

    Args:
        rho: square matrix expected to be a normalized density matrix, i.e.
            symmetric/Hermitian with unit trace. It is symmetrized before the
            eigendecomposition, so tiny floating-point asymmetries are harmless.

    Returns:
        The entropy in nats, as a Python float. Eigenvalues ``<= 1e-12`` are
        skipped. This covers the small negative values round-off can produce
        and implements the ``0 * log 0 = 0`` convention.
    """
    # A density matrix is Hermitian, so use the Hermitian eigensolver. It
    # returns real eigenvalues directly and is faster and more stable than the
    # general `eig`, which returns complex values with spurious imaginary parts.
    hermitian = (rho + rho.transpose(-2, -1).conj()) / 2
    eigenvalues = torch.linalg.eigvalsh(hermitian)
    eps = 1e-12
    positive = eigenvalues[eigenvalues > eps]
    entropy = -torch.sum(positive * torch.log(positive))
    return entropy.item()

def compute_joint_entropy(gram1, gram2):
    """
    Compute the joint von Neumann entropy of two Gram matrices.

    Each Gram matrix is first scaled by its own trace. The two are combined
    with an element-wise (Hadamard) product, and that product is scaled by its
    trace before being passed to `compute_von_neumann_entropy_from_density`.

    Args:
        gram1: 2-D tensor (``torch.trace`` requires 2-D input).
        gram2: 2-D tensor combined element-wise with ``gram1``; presumably
            the same shape. TODO confirm with callers.

    Returns:
        The entropy value returned by
        `compute_von_neumann_entropy_from_density`.
    """
    unit1 = gram1 / torch.trace(gram1)
    unit2 = gram2 / torch.trace(gram2)
    product = unit1 * unit2
    density = product / torch.trace(product)
    return compute_von_neumann_entropy_from_density(density)
